import numpy as np
from sklearn.svm import LinearSVC
from sklearn.linear_model import Ridge


class reweightEG():
    """Reweighted exponentiated-gradient feature selection with an l1,2 group penalty.

    The algorithm alternates between (1) fitting a linear model on a
    column-reweighted design matrix (hinge loss via LinearSVC when the
    target is binary, squared loss via Ridge otherwise) and (2) updating
    the per-feature reweighting from the l1 norms of the current weights
    within each feature group, until the weight vector stops moving.

    Parameters
    ----------
    alpha : float
        Regularization strength of the inner solver (passed as ``C`` to
        LinearSVC, as ``alpha`` to Ridge).
    idx_group : array-like of shape (n_group, n_feature), optional
        Binary group-membership matrix; entry (g, j) nonzero means
        feature j belongs to group g. When None a random partition into
        ``n_group`` groups is drawn at fit time.
    n_group : int, optional
        Number of random groups to create when ``idx_group`` is None.
        Required in that case.
    crit : float
        Convergence threshold on ||w_prev - w||_2.
    n_iter : int
        Maximum number of outer iterations (must be >= 1).
    verbose : int
        When 1, print the criterion value every iteration.

    Raises
    ------
    ValueError
        If ``n_iter`` < 1.
    KeyError
        If both ``idx_group`` and ``n_group`` are None.
    """

    def __init__(self, alpha=1, idx_group=None, n_group=None, crit=5 * 10 ** -4, n_iter=10 ** 6, verbose=0):

        self.coef = None        # learned weight vector; entries with |w| <= 1e-3 zeroed after fit
        self.idx = None         # indices of selected features (|w| > 1e-3) after fit
        self.alpha = alpha
        self.idx_group = idx_group
        self.n_group = n_group
        self.loss_func = None   # set in fit: 'hinge' (binary y) or 'square' (otherwise)
        # Convergence threshold; historical attribute name kept so any
        # external reader of ``alpharit`` keeps working.
        self.alpharit = crit
        self.n_iter = n_iter
        self.verbose = verbose
        self.converged = False

        if n_iter < 1:
            raise ValueError('At least one iteration is required.')

        if idx_group is None and n_group is None:
            raise KeyError('n_group must be specified if idx_group is None.')

    def _compute_G(self, w, feat_group):
        """Return the diagonal of the reweighting matrix G.

        ``G[j] = sqrt(|w_j| / ||w_{group(j)}||_1)``, with zero group norms
        clamped to 1e-9 so the division is always defined.

        Parameters
        ----------
        w : array-like
            Current weight vector (flattened to 1-D).
        feat_group : dict
            Maps every feature index 0..n_feature-1 to its group index.
        """
        # Flatten to 1-D (row-major), matching the solver's coef_ shape.
        w = np.ravel(w)
        n_group = len(self.idx_group_new)
        n_feature = w.shape[0]

        # l1 norm of the weights inside each group.
        w_group_norm = np.empty(n_group)
        for group_counter in range(n_group):
            w_group = w[self.idx_group_new[group_counter]]
            w_group_norm[group_counter] = np.linalg.norm(w_group, ord=1)
        # Clamp all-zero groups to avoid division by zero below.
        w_group_norm[w_group_norm == 0] = 10 ** -9

        # Vectorized gather: map each feature to its group's norm, then scale.
        group_of_feature = np.array([feat_group[j] for j in range(n_feature)])
        G_diag = np.sqrt(np.abs(w) / w_group_norm[group_of_feature])

        return G_diag

    def _compute_X_tran(self, X, G_diag):
        """Scale each column of X by the matching entry of G_diag.

        Equivalent to ``X @ np.diag(G_diag)`` but uses broadcasting, which
        is O(n_sample * n_feature) instead of materializing a dense
        n_feature x n_feature diagonal matrix.
        """
        return X * G_diag

    def _compute_w_tran(self, X_tran, y):
        """Fit the inner linear model on the reweighted design matrix.

        Returns the flattened coefficient vector of LinearSVC (hinge) or
        Ridge (square), depending on ``self.loss_func``. Both are fit
        without an intercept so the reweighting stays a pure column scale.
        """
        w = 0
        if self.loss_func == 'hinge':
            clf = LinearSVC(fit_intercept=False, C=self.alpha)
            clf.fit(X_tran, y)
            w = clf.coef_
        elif self.loss_func == 'square':
            clf = Ridge(alpha=self.alpha, fit_intercept=False, tol=10 ** -9)
            clf.fit(X_tran, y)
            w = clf.coef_

        return np.ravel(w)

    def _create_rand_group(self, n_feature):
        """Randomly partition the features into ``self.n_group`` groups.

        Builds ``self.idx_group`` as a binary (n_group, n_feature)
        membership matrix from a random permutation of the feature indices
        split into near-equal chunks.
        """
        self.idx_group = np.zeros((self.n_group, n_feature))
        idx = np.random.permutation(n_feature)
        # Split the shuffled indices into n_group near-equal chunks.
        idx = np.array_split(idx, self.n_group)
        for sub_counter, sub_idx in enumerate(idx):
            self.idx_group[sub_counter, sub_idx] = 1

    def _l12_norm(self, X, y):
        """Run the alternating reweighting iteration until convergence.

        Sets ``self.coef`` (small entries zeroed), ``self.idx`` (selected
        feature indices) and ``self.converged``.
        """
        if len(X.shape) == 1:
            n_sample = X.shape[0]
            n_feature = 1
        else:
            n_sample, n_feature = X.shape

        # Choose the loss: hinge for binary classification, squared
        # loss (Ridge regression) for anything else.
        if len(np.unique(y)) == 2:
            self.loss_func = 'hinge'
        else:
            self.loss_func = 'square'

        if self.idx_group is None:
            self._create_rand_group(n_feature)

        # Precompute, per group, the array of member feature indices
        # (idx_group_new) and the reverse feature -> group map (feat_group).
        self.idx_group_new = []
        feat_group = {}
        for group_counter in range(self.idx_group.shape[0]):
            temp = np.nonzero(self.idx_group[group_counter, :])[0]
            self.idx_group_new.append(temp)
            for idx_feature in temp:
                feat_group[idx_feature] = group_counter

        # Uniform starting point, then one half-step to seed the loop.
        w = np.ones(n_feature) / n_feature
        G_diag = self._compute_G(w, feat_group)
        X_tran = self._compute_X_tran(X, G_diag)
        w_tran = self._compute_w_tran(X_tran, y)

        counter = 0
        while True:
            counter += 1

            w_pre = w.copy()
            # Map the solver's coefficients back to the original scale.
            w = np.multiply(w_tran, G_diag)

            G_diag = self._compute_G(w, feat_group)
            X_tran = self._compute_X_tran(X, G_diag)
            w_tran = self._compute_w_tran(X_tran, y)

            # l2 distance between consecutive iterates is the stop criterion.
            temp = np.linalg.norm(w_pre - w)
            if self.verbose == 1:
                print('iteration: %d, criteria: %.4f.' % (counter, temp))

            if temp <= self.alpharit or counter >= self.n_iter:
                break

        # Hard-threshold: keep features with |w| > 1e-3, zero the rest.
        self.coef = w
        self.idx = np.where(np.abs(w) > 10 ** -3)[0]
        self.coef[np.where(np.abs(w) <= 10 ** -3)] = 0

        if counter < self.n_iter:
            self.converged = True

    def fit(self, X, y):
        """Fit the model to design matrix X and targets y."""
        self._l12_norm(X, y)

    def predict(self, X):
        """Predict targets for X.

        Returns the sign of the decision value for the hinge loss
        (binary labels in {-1, 0, +1}; 0 only on exact ties), or the raw
        linear response for the squared loss.
        """
        if self.loss_func == 'hinge':
            return np.ravel(np.sign(np.dot(X, self.coef)))
        else:
            return np.ravel(np.dot(X, self.coef))